# -*- coding: utf-8 -*-
# @author: HRUN

import asyncio
import json
import logging
import socket
import ssl
import struct
import threading
import time
import uuid
from collections import Counter
from concurrent.futures import ThreadPoolExecutor
from dataclasses import asdict, dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple

import aiohttp

# Logger for this module, named after the module itself (``__name__``).
logger = logging.getLogger(name=__name__)


@dataclass
class ProtocolTestConfig:
    """Configuration for a single protocol load test.

    Field order is part of the positional constructor interface; do not
    reorder. Fields that default to ``None`` are annotated ``Optional`` so
    type checkers accept the defaults.
    """
    protocol: str  # HTTP, HTTPS, TCP, UDP or WebSocket (matched case-insensitively by the engine)
    host: str
    port: int
    timeout: int = 10  # per-request / socket timeout, in seconds
    concurrent_users: int = 1  # number of parallel workers
    duration: int = 60  # total test length, in seconds
    request_data: Optional[Dict] = None  # JSON body (HTTP POST/PUT/PATCH) or WebSocket message
    headers: Optional[Dict] = None  # extra HTTP request headers
    method: str = 'GET'  # For HTTP/HTTPS
    path: str = '/'  # For HTTP/HTTPS; also used as the WebSocket path
    ssl_verify: bool = True  # HTTPS only: whether to verify the server certificate
    custom_payload: Optional[bytes] = None  # For TCP/UDP; a generated text payload is used when empty


@dataclass
class ProtocolTestResult:
    """Aggregated outcome of one protocol test run.

    As populated by MultiProtocolTestEngine, response times are in
    milliseconds, ``error_rate`` is a percentage (0-100) and ``throughput``
    is requests per second.
    """
    protocol: str
    start_time: datetime
    end_time: datetime
    total_requests: int = 0
    successful_requests: int = 0
    failed_requests: int = 0
    total_bytes_sent: int = 0
    total_bytes_received: int = 0
    avg_response_time: float = 0.0  # milliseconds
    min_response_time: float = 0.0  # milliseconds
    max_response_time: float = 0.0  # milliseconds
    error_rate: float = 0.0  # percentage of failed requests
    throughput: float = 0.0  # requests per second
    # default_factory gives every instance its own container instead of a shared one.
    errors: List[str] = field(default_factory=list)
    detailed_metrics: Dict = field(default_factory=dict)

    def __post_init__(self):
        # Keep the historical contract: passing None explicitly still yields
        # empty containers rather than None.
        if self.errors is None:
            self.errors = []
        if self.detailed_metrics is None:
            self.detailed_metrics = {}


class MultiProtocolTestEngine:
    """Multi-protocol load-test engine.

    Runs ``config.concurrent_users`` workers against ``config.host:config.port``
    until ``config.duration`` seconds have elapsed or :meth:`stop_test` is
    called, then merges the per-worker counters into a ProtocolTestResult.
    Supported protocols (case-insensitive): HTTP, HTTPS, TCP, UDP, WebSocket.
    """

    def __init__(self, config: 'ProtocolTestConfig'):
        self.config = config
        self.results = []
        self.running = False
        self.start_time = None
        # NOTE(review): the attributes below are not used by the workers in
        # this class (each worker keeps its own counters); kept for
        # compatibility with any external readers.
        self.response_times = []
        self.errors = []
        self.total_bytes_sent = 0
        self.total_bytes_received = 0
        self.lock = threading.Lock()

    def run_test(self) -> 'ProtocolTestResult':
        """Run the configured protocol test and block until it finishes.

        Returns:
            The aggregated ProtocolTestResult.

        Raises:
            ValueError: if ``config.protocol`` is not supported.
            Exception: any error escaping the protocol runner is logged and
                re-raised.
        """
        logger.info(f"开始 {self.config.protocol} 协议测试 - {self.config.host}:{self.config.port}")

        runners = {
            'HTTP': self._run_http_test,
            'HTTPS': self._run_http_test,
            'TCP': self._run_tcp_test,
            'UDP': self._run_udp_test,
            'WEBSOCKET': self._run_websocket_test,
        }

        self.running = True
        self.start_time = datetime.now()

        try:
            runner = runners.get(self.config.protocol.upper())
            if runner is None:
                raise ValueError(f"不支持的协议: {self.config.protocol}")
            return runner()
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception(f"协议测试失败: {e}")
            raise
        finally:
            self.running = False

    # ------------------------------------------------------------------
    # Shared worker helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _new_worker_results() -> Dict[str, Any]:
        """Return a fresh per-worker counter dict in the shape _aggregate_results reads."""
        return {
            'requests': 0,
            'successful': 0,
            'failed': 0,
            'response_times': [],
            'bytes_sent': 0,
            'bytes_received': 0,
            'errors': []
        }

    @staticmethod
    def _record_success(worker_results: Dict[str, Any], elapsed_ms: float) -> None:
        """Count one successful request and store its latency (milliseconds)."""
        worker_results['requests'] += 1
        worker_results['successful'] += 1
        worker_results['response_times'].append(elapsed_ms)

    @staticmethod
    def _record_failure(worker_results: Dict[str, Any], error: str) -> None:
        """Count one failed request and store its error message."""
        worker_results['requests'] += 1
        worker_results['failed'] += 1
        worker_results['errors'].append(error)

    @staticmethod
    def _format_error(exc: BaseException) -> str:
        """Render an exception as ``"<TypeName>: <message>"`` (or just the type name).

        Some exceptions (e.g. timeouts) stringify to ``""``; the type-name
        prefix keeps them identifiable and lets _get_error_distribution group
        errors by exception type.
        """
        message = str(exc)
        name = type(exc).__name__
        return f"{name}: {message}" if message else name

    @staticmethod
    def _default_payload(label: str, worker_id: int) -> bytes:
        """Build the generated newline-terminated TCP/UDP payload for *worker_id*."""
        return f"{label} Test {worker_id} - {datetime.now().isoformat()}\n".encode()

    def _should_continue(self) -> bool:
        """True while the test has not been stopped and ``config.duration`` has not elapsed."""
        return self.running and (datetime.now() - self.start_time).total_seconds() < self.config.duration

    def _run_thread_workers(self, worker) -> List[Dict]:
        """Run ``worker(worker_id)`` once per concurrent user on a thread pool.

        Results are returned in worker-id order. With ``concurrent_users <= 0``
        no workers run and an empty list is returned (ThreadPoolExecutor
        rejects ``max_workers=0``).
        """
        users = self.config.concurrent_users
        if users <= 0:
            return []
        with ThreadPoolExecutor(max_workers=users) as executor:
            futures = [executor.submit(worker, worker_id) for worker_id in range(users)]
            return [future.result() for future in futures]

    # ------------------------------------------------------------------
    # Protocol runners
    # ------------------------------------------------------------------

    def _run_http_test(self) -> 'ProtocolTestResult':
        """Run an HTTP/HTTPS load test with ``concurrent_users`` async workers.

        All workers share one aiohttp session. A request is successful only
        when a response with status < 400 arrives. 4xx/5xx responses are
        recorded as failures (``"HTTP <status>"``), as are transport errors.
        Body bytes are counted only when a body is actually sent.
        """
        is_https = self.config.protocol.upper() == 'HTTPS'
        url = f"{self.config.protocol.lower()}://{self.config.host}:{self.config.port}{self.config.path}"

        # Request options never change during a run, so build them once.
        request_kwargs = {
            'timeout': aiohttp.ClientTimeout(total=self.config.timeout),
            'ssl': self.config.ssl_verify if is_https else None
        }
        if self.config.headers:
            request_kwargs['headers'] = self.config.headers

        body_size = 0
        if self.config.method.upper() in ['POST', 'PUT', 'PATCH'] and self.config.request_data:
            request_kwargs['json'] = self.config.request_data
            body_size = len(json.dumps(self.config.request_data).encode())

        async def http_worker(session, worker_id):
            """Send requests back-to-back until the test stops; return this worker's counters."""
            worker_results = self._new_worker_results()

            while self._should_continue():
                try:
                    started = time.perf_counter()
                    async with session.request(self.config.method, url, **request_kwargs) as response:
                        response_data = await response.read()
                        status = response.status
                    elapsed_ms = (time.perf_counter() - started) * 1000

                    worker_results['bytes_sent'] += body_size
                    worker_results['bytes_received'] += len(response_data)
                    if status >= 400:
                        self._record_failure(worker_results, f"HTTP {status}")
                    else:
                        self._record_success(worker_results, elapsed_ms)
                except Exception as e:
                    self._record_failure(worker_results, self._format_error(e))

                # Short pause to avoid excessive resource consumption.
                await asyncio.sleep(0.001)

            return worker_results

        async def run_async_test():
            """Start one http_worker per concurrent user and gather their results."""
            connector = aiohttp.TCPConnector(
                limit=self.config.concurrent_users * 2,
                ssl=False if self.config.protocol.upper() == 'HTTP' else None
            )
            async with aiohttp.ClientSession(connector=connector) as session:
                tasks = [asyncio.create_task(http_worker(session, i))
                         for i in range(self.config.concurrent_users)]
                return await asyncio.gather(*tasks, return_exceptions=True)

        # asyncio.run creates a private loop, and closes and unsets it afterwards.
        return self._aggregate_results(asyncio.run(run_async_test()))

    def _run_tcp_test(self) -> 'ProtocolTestResult':
        """Run a TCP test: each iteration connects, sends one payload and reads one reply (up to 4096 bytes)."""
        address = (self.config.host, self.config.port)

        def tcp_worker(worker_id):
            """Loop connect/send/receive until the test stops; return this worker's counters."""
            worker_results = self._new_worker_results()

            while self._should_continue():
                try:
                    started = time.perf_counter()
                    # create_connection resolves the host (IPv4 or IPv6) and applies the
                    # timeout to connect and I/O; the with-block guarantees close().
                    with socket.create_connection(address, timeout=self.config.timeout) as sock:
                        payload = self.config.custom_payload or self._default_payload('TCP', worker_id)
                        # sendall() keeps writing until every byte is sent; send() may stop short.
                        sock.sendall(payload)
                        worker_results['bytes_sent'] += len(payload)

                        response = sock.recv(4096)
                        worker_results['bytes_received'] += len(response)
                        elapsed_ms = (time.perf_counter() - started) * 1000
                    self._record_success(worker_results, elapsed_ms)
                except Exception as e:
                    self._record_failure(worker_results, self._format_error(e))

                time.sleep(0.01)  # short pause between iterations

            return worker_results

        return self._aggregate_results(self._run_thread_workers(tcp_worker))

    def _run_udp_test(self) -> 'ProtocolTestResult':
        """Run a UDP test: each iteration sends one datagram and waits for an optional reply.

        A missing reply (socket timeout) is not an error; the iteration still
        counts as successful, with its latency including the wait.
        """
        address = (self.config.host, self.config.port)

        def udp_worker(worker_id):
            """Loop send/receive until the test stops; return this worker's counters."""
            worker_results = self._new_worker_results()

            while self._should_continue():
                try:
                    started = time.perf_counter()
                    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
                        sock.settimeout(self.config.timeout)

                        payload = self.config.custom_payload or self._default_payload('UDP', worker_id)
                        sock.sendto(payload, address)
                        worker_results['bytes_sent'] += len(payload)

                        try:
                            response, _peer = sock.recvfrom(4096)
                            worker_results['bytes_received'] += len(response)
                        except socket.timeout:
                            # UDP may legitimately get no reply; not counted as an error.
                            pass
                        elapsed_ms = (time.perf_counter() - started) * 1000
                    self._record_success(worker_results, elapsed_ms)
                except Exception as e:
                    self._record_failure(worker_results, self._format_error(e))

                time.sleep(0.01)  # short pause between iterations

            return worker_results

        return self._aggregate_results(self._run_thread_workers(udp_worker))

    def _run_websocket_test(self) -> 'ProtocolTestResult':
        """Run a WebSocket request/reply test over ``ws://``.

        Each worker opens one connection. It then repeatedly sends a text
        message (``request_data`` as JSON, or a generated string) and waits up
        to ``config.timeout`` seconds for the next incoming message. A worker
        that cannot connect counts as one failed request. A worker stops early
        once the server closes the connection.
        """
        url = f"ws://{self.config.host}:{self.config.port}{self.config.path}"
        closed_types = (
            aiohttp.WSMsgType.CLOSE,
            aiohttp.WSMsgType.CLOSING,
            aiohttp.WSMsgType.CLOSED,
            aiohttp.WSMsgType.ERROR,
        )

        async def websocket_worker(worker_id):
            """Send/receive over one connection until the test stops; return this worker's counters."""
            worker_results = self._new_worker_results()

            try:
                async with aiohttp.ClientSession() as session:
                    async with session.ws_connect(url) as ws:
                        while self._should_continue():
                            try:
                                started = time.perf_counter()

                                if self.config.request_data:
                                    message = json.dumps(self.config.request_data)
                                else:
                                    message = f"WebSocket Test {worker_id} - {datetime.now().isoformat()}"

                                await ws.send_str(message)
                                worker_results['bytes_sent'] += len(message.encode())

                                # Bounded wait: an unbounded receive() would hang this
                                # worker (and the whole test) if the server never replies.
                                msg = await ws.receive(timeout=self.config.timeout)
                                if msg.type in closed_types:
                                    # The connection is gone; further sends can only fail.
                                    self._record_failure(worker_results, f"WebSocket closed: {msg.type.name}")
                                    break
                                if msg.type == aiohttp.WSMsgType.TEXT:
                                    worker_results['bytes_received'] += len(msg.data.encode())
                                elif msg.type == aiohttp.WSMsgType.BINARY:
                                    worker_results['bytes_received'] += len(msg.data)

                                self._record_success(worker_results, (time.perf_counter() - started) * 1000)
                            except Exception as e:
                                self._record_failure(worker_results, self._format_error(e))

                            await asyncio.sleep(0.01)

            except Exception as e:
                # Count the failed connection as a failed request so error_rate reflects it.
                worker_results['requests'] += 1
                worker_results['failed'] += 1
                worker_results['errors'].append(f"WebSocket连接失败: {self._format_error(e)}")

            return worker_results

        async def run_async_websocket_test():
            """Start one websocket_worker per concurrent user and gather their results."""
            tasks = [asyncio.create_task(websocket_worker(i))
                     for i in range(self.config.concurrent_users)]
            return await asyncio.gather(*tasks, return_exceptions=True)

        return self._aggregate_results(asyncio.run(run_async_websocket_test()))

    # ------------------------------------------------------------------
    # Aggregation
    # ------------------------------------------------------------------

    def _aggregate_results(self, worker_results: List[Dict]) -> 'ProtocolTestResult':
        """Merge per-worker counter dicts into one ProtocolTestResult.

        *worker_results* may contain exception objects
        (``asyncio.gather(..., return_exceptions=True)``). Their messages are
        added to the error list instead of being silently discarded. Response
        times are in milliseconds, ``error_rate`` is a percentage, and
        throughput/byte rates are per second of wall-clock time since
        ``start_time``. At most 100 error messages are kept in ``errors``;
        ``error_distribution`` covers all of them.
        """
        end_time = datetime.now()
        total_duration = (end_time - self.start_time).total_seconds()

        total_requests = 0
        successful_requests = 0
        failed_requests = 0
        total_bytes_sent = 0
        total_bytes_received = 0
        all_response_times = []
        all_errors = []

        for result in worker_results:
            if isinstance(result, BaseException):
                # A worker task crashed; surface it rather than dropping it.
                all_errors.append(self._format_error(result))
                continue
            if not isinstance(result, dict):
                continue
            total_requests += result.get('requests', 0)
            successful_requests += result.get('successful', 0)
            failed_requests += result.get('failed', 0)
            total_bytes_sent += result.get('bytes_sent', 0)
            total_bytes_received += result.get('bytes_received', 0)
            all_response_times.extend(result.get('response_times', []))
            all_errors.extend(result.get('errors', []))

        def per_second(amount):
            return amount / total_duration if total_duration > 0 else 0

        if all_response_times:
            avg_response_time = sum(all_response_times) / len(all_response_times)
            min_response_time = min(all_response_times)
            max_response_time = max(all_response_times)
        else:
            avg_response_time = min_response_time = max_response_time = 0
        error_rate = (failed_requests / total_requests * 100) if total_requests > 0 else 0

        detailed_metrics = {
            'duration': total_duration,
            'concurrent_users': self.config.concurrent_users,
            'response_time_percentiles': self._calculate_percentiles(all_response_times),
            'bytes_per_second_sent': per_second(total_bytes_sent),
            'bytes_per_second_received': per_second(total_bytes_received),
            'error_distribution': self._get_error_distribution(all_errors)
        }

        return ProtocolTestResult(
            protocol=self.config.protocol,
            start_time=self.start_time,
            end_time=end_time,
            total_requests=total_requests,
            successful_requests=successful_requests,
            failed_requests=failed_requests,
            total_bytes_sent=total_bytes_sent,
            total_bytes_received=total_bytes_received,
            avg_response_time=avg_response_time,
            min_response_time=min_response_time,
            max_response_time=max_response_time,
            error_rate=error_rate,
            throughput=per_second(total_requests),
            errors=all_errors[:100],  # cap the stored error messages
            detailed_metrics=detailed_metrics
        )

    def _calculate_percentiles(self, response_times: List[float]) -> Dict[str, float]:
        """Return nearest-rank p50/p75/p90/p95/p99 of *response_times*.

        The nearest-rank p-th percentile is the sample at 1-based rank
        ``ceil(p / 100 * n)``. A floor-based rank would under-report whenever
        ``p * n / 100`` is fractional (e.g. p50 of three samples would be the
        minimum). Returns an empty dict for empty input.
        """
        if not response_times:
            return {}

        sorted_times = sorted(response_times)
        length = len(sorted_times)

        percentiles = {}
        for p in (50, 75, 90, 95, 99):
            # Integer ceil(length * p / 100) avoids float rounding surprises.
            rank = (length * p + 99) // 100
            percentiles[f'p{p}'] = sorted_times[max(rank - 1, 0)]

        return percentiles

    def _get_error_distribution(self, errors: List[str]) -> Dict[str, int]:
        """Count errors by category: the text before the first ':' (or the whole message if it has none)."""
        return dict(Counter(error.partition(':')[0] for error in errors))

    def stop_test(self):
        """Ask all workers to stop; they exit after their current iteration."""
        self.running = False


class ProtocolTestManager:
    """Registry of protocol tests that run on background threads.

    ``active_tests`` maps test id -> running MultiProtocolTestEngine.
    ``test_results`` maps test id -> ProtocolTestResult, or ``None`` when
    the run raised. Both dicts are shared between the caller's thread and
    each test's background thread, so access goes through ``_lock``.
    """

    def __init__(self):
        self.active_tests = {}
        self.test_results = {}
        self._lock = threading.Lock()

    def start_protocol_test(self, test_id: str, config: 'ProtocolTestConfig') -> str:
        """Start *config* as test *test_id* on a daemon thread and return *test_id*.

        Raises:
            ValueError: if a test with the same id is still running.
        """
        with self._lock:
            if test_id in self.active_tests:
                raise ValueError(f"测试 {test_id} 已在运行")
            engine = MultiProtocolTestEngine(config)
            # Register BEFORE starting the thread. Otherwise a test that ends
            # immediately (e.g. an unsupported protocol) would remove its entry
            # before it was added, and then appear to be running forever.
            self.active_tests[test_id] = engine

        def run_test():
            result = None  # stays None if the run raises -> reported as 'failed'
            try:
                result = engine.run_test()
                logger.info(f"协议测试完成: {test_id}")
            except Exception as e:
                logger.error(f"协议测试失败 {test_id}: {e}")
            finally:
                # Store the outcome and leave 'running' atomically, so a status
                # query never sees the test as missing in between.
                with self._lock:
                    self.test_results[test_id] = result
                    self.active_tests.pop(test_id, None)

        test_thread = threading.Thread(target=run_test, daemon=True)
        try:
            test_thread.start()
        except Exception:
            # The thread never ran: roll back the registration so the id is not stuck.
            with self._lock:
                self.active_tests.pop(test_id, None)
            raise

        logger.info(f"启动协议测试: {test_id} - {config.protocol}")
        return test_id

    def stop_protocol_test(self, test_id: str) -> bool:
        """Request a running test to stop; return False if *test_id* is not running."""
        with self._lock:
            # Single lookup: the test thread may remove the entry at any moment.
            engine = self.active_tests.get(test_id)
        if engine is None:
            return False
        engine.stop_test()
        return True

    def get_test_status(self, test_id: str) -> Dict:
        """Return ``{'status': 'running'|'completed'|'failed'|'not_found', 'test_id': ...}``.

        Completed tests also carry ``'result'`` (the ProtocolTestResult as a dict).
        """
        with self._lock:
            is_running = test_id in self.active_tests
            has_outcome = test_id in self.test_results
            result = self.test_results.get(test_id)

        if is_running:
            return {
                'status': 'running',
                'test_id': test_id
            }
        if not has_outcome:
            return {
                'status': 'not_found',
                'test_id': test_id
            }
        if result:
            return {
                'status': 'completed',
                'test_id': test_id,
                'result': asdict(result)
            }
        return {
            'status': 'failed',
            'test_id': test_id
        }

    def get_test_result(self, test_id: str) -> Optional['ProtocolTestResult']:
        """Return the stored result for *test_id*, or None if absent or failed."""
        with self._lock:
            return self.test_results.get(test_id)

    def list_active_tests(self) -> List[str]:
        """Return the ids of tests that are currently running."""
        with self._lock:
            return list(self.active_tests.keys())

    def cleanup_old_results(self, max_results: int = 100):
        """Keep only the *max_results* most recently finished results.

        Failed runs (stored as None) sort as oldest and are dropped first.
        """
        with self._lock:
            if len(self.test_results) > max_results:
                newest_first = sorted(
                    self.test_results.items(),
                    key=lambda item: item[1].end_time if item[1] else datetime.min,
                    reverse=True
                )
                self.test_results = dict(newest_first[:max_results])


# Process-wide ProtocolTestManager, shared by every caller of get_protocol_test_manager().
protocol_test_manager = ProtocolTestManager()


def get_protocol_test_manager() -> ProtocolTestManager:
    """Return the module-level ProtocolTestManager singleton.

    Every call hands back the same instance, so tests started through it are
    visible to all callers in this process.
    """
    shared_manager = protocol_test_manager
    return shared_manager